# core/formalization/attack_manager.py
import matplotlib.pyplot as plt
import networkx as nx
from matplotlib.figure import Figure
from matplotlib.axes import Axes
from typing import Dict, List, Tuple, Optional, Set, Union
import datetime
import os

from utils.logger import Logger
from llm.llm_wrapper import LLMWrapper
from llm.auxiliary import Auxiliary
from core.formalization.symbol_manager import SymbolManager
from core.formalization.knowledge_graph import FormalizationKnowledgeGraph
from core.formalization.rl.rl_manager import RLManager

class AttackManager:

    def __init__(
        self,
        logger: Logger,
        llm: LLMWrapper,
        auxiliary: Auxiliary,
        config: Dict = {},
    ):
        self.logger = logger
        self.llm = llm
        self.auxiliary = auxiliary
        self.config = config
        
        self.symbol_manager = SymbolManager(logger, auxiliary, config)
        self.knowledge_graph = FormalizationKnowledgeGraph(logger, auxiliary, self.symbol_manager, config)
        self.rl_manager = RLManager(logger, llm, auxiliary, self.symbol_manager, config)

        self.top_k = config.get("top_k", 10)
        self.similarity_threshold = config.get("similarity_threshold", 0.7)

        self.history = []

    def process_query(self, original_query: str, target: str, category: str):
        self.logger.info(f"Processing query: {original_query}")

        query = original_query
        relevant_knowledge = self.knowledge_graph.search_knowledge(query, self.top_k)
        if relevant_knowledge:
            query = self.knowledge_graph.enhance_query_with_knowledge(query, relevant_knowledge)
            self.logger.info(f"Enhanced query: {query}")

        result = self.rl_manager.predict(query, target, category)
        if not result:
            self.logger.warning(f"Process query: {original_query} failed with empty result")
            return None
        final_info = result[-1]
        final_query = final_info['cur_query']
        final_response = final_info['response']
        success = final_info['success']

        updated_nodes = []
        if success:
            updated_nodes = self.knowledge_graph.extract_and_update_knowledge(query, final_response, category)
            history_entry = {
                "timestamp": datetime.datetime.now().isoformat(),
                "original_query": original_query,
                "query": query,
                "final_query": final_query,
                "category": category,
                "response": final_response,
                "relevant_knowledge": [node.get('id') for node in relevant_knowledge],
                "updated_nodes": updated_nodes
            }
            self.history.append(history_entry)
        return final_info

    def get_knowledge_by_term(self, term: str) -> Optional[Dict]:
        return self.knowledge_graph.get_term_by_name(term)
    
    def get_related_terms(self, term: str, max_depth: int = 2) -> List[Dict]:
        return self.knowledge_graph.get_related_terms(term, max_depth)

    def get_history(self, limit: int = None) -> List[Dict]:
        if limit:
            return self.history[-limit:]
        return self.history
    
    def clear_history(self):
        self.history = []

    def visualize_knowledge_graph(self, 
                                output_path: Optional[str] = None, 
                                title: str = "Knowledge Graph Visualization",
                                figsize: Tuple[int, int] = (16, 12),
                                node_size: int = 800,
                                edge_width: float = 1.5,
                                font_size: int = 10,
                                show_labels: bool = True) -> Figure:
        G = nx.DiGraph()
        
        category_colors = self._get_category_colors()
        node_colors = []
        
        for node_id, node_data in self.knowledge_graph.nodes.items():
            category = node_data.get('category', 'unknown')
            G.add_node(node_id, label=node_data.get('term', ''), category=category)
            node_colors.append(category_colors.get(category, 'gray'))
        
        edge_colors = []
        edge_labels = {}
        
        for source_id, edges in self.knowledge_graph.edges.items():
            for target_id, relation_type in edges:
                G.add_edge(source_id, target_id, relation=relation_type)
                edge_colors.append(self._get_relation_color(relation_type))
                edge_labels[(source_id, target_id)] = relation_type
        
        fig, ax = plt.subplots(figsize=figsize)
        pos = nx.spring_layout(G, k=0.15, iterations=50)
        nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=node_size, alpha=0.8, ax=ax)
        nx.draw_networkx_edges(G, pos, edge_color=edge_colors, width=edge_width, 
                            arrowsize=15, arrowstyle='->', alpha=0.7, ax=ax)
        
        if show_labels:
            labels = {node_id: G.nodes[node_id]['label'] for node_id in G.nodes}
            nx.draw_networkx_labels(G, pos, labels=labels, font_size=font_size, font_weight='bold', ax=ax)
        
        legend_elements = [plt.Line2D([0], [0], marker='o', color='w', 
                                    markerfacecolor=color, markersize=10, label=category)
                        for category, color in category_colors.items() if any(G.nodes[n]['category'] == category for n in G.nodes)]
        
        ax.legend(handles=legend_elements, loc='upper right', title="Categories")
        
        plt.title(title, fontsize=16)
        plt.axis('off')
        
        if output_path:
            plt.savefig(output_path, bbox_inches='tight', dpi=300)
        
        return fig

    def visualize_knowledge_graph_by_category(self,
                                            output_dir: Optional[str] = None,
                                            figsize: Tuple[int, int] = (12, 10),
                                            node_size: int = 800,
                                            edge_width: float = 1.5,
                                            font_size: int = 10,
                                            show_labels: bool = True) -> Dict[str, Figure]:
        categories = self._get_all_categories()
        figures = {}
        
        for category in categories:
            fig, ax = self._create_category_subgraph(
                category, figsize, node_size, edge_width, font_size, show_labels
            )
            
            figures[category] = fig
            
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
                output_path = os.path.join(output_dir, f"knowledge_graph_{category}.png")
                fig.savefig(output_path, bbox_inches='tight', dpi=300)
        
        return figures

    def _create_category_subgraph(self, 
                                category: str,
                                figsize: Tuple[int, int],
                                node_size: int,
                                edge_width: float,
                                font_size: int,
                                show_labels: bool) -> Tuple[Figure, Axes]:
        G = nx.DiGraph()
        
        category_colors = self._get_category_colors()
        category_color = category_colors.get(category, 'gray')
        
        category_node_ids = [
            node_id for node_id, node_data in self.knowledge_graph.nodes.items()
            if node_data.get('category', '') == category
        ]
        
        for node_id in category_node_ids:
            node_data = self.knowledge_graph.nodes[node_id]
            G.add_node(node_id, label=node_data.get('term', ''))
        
        edge_colors = []
        edge_labels = {}
        
        for source_id in category_node_ids:
            for target_id, relation_type in self.knowledge_graph.edges.get(source_id, []):
                if target_id in category_node_ids:
                    G.add_edge(source_id, target_id, relation=relation_type)
                    edge_colors.append(self._get_relation_color(relation_type))
                    edge_labels[(source_id, target_id)] = relation_type
        
        fig, ax = plt.subplots(figsize=figsize)
        pos = nx.spring_layout(G, k=0.2, iterations=50)
        nx.draw_networkx_nodes(G, pos, node_color=category_color, node_size=node_size, alpha=0.8, ax=ax)
        nx.draw_networkx_edges(G, pos, edge_color=edge_colors, width=edge_width, 
                            arrowsize=15, arrowstyle='->', alpha=0.7, ax=ax)
        
        if show_labels:
            labels = {node_id: G.nodes[node_id]['label'] for node_id in G.nodes}
            nx.draw_networkx_labels(G, pos, labels=labels, font_size=font_size, font_weight='bold', ax=ax)
        
        plt.title(f"Knowledge Graph - Category: {category}", fontsize=16)
        plt.axis('off')
        
        return fig, ax

    def _get_all_categories(self) -> List[str]:
        categories = set()
        for node_data in self.knowledge_graph.nodes.values():
            category = node_data.get('category', 'unknown')
            if category:
                categories.add(category)
        return sorted(list(categories))

    def _get_category_colors(self) -> Dict[str, str]:
        categories = self._get_all_categories()
        
        import matplotlib.cm as cm
        import matplotlib.colors as mcolors
        
        colormap = cm.get_cmap('tab20', len(categories))
        
        return {
            category: mcolors.rgb2hex(colormap(i)[:3])
            for i, category in enumerate(categories)
        }

    def _get_relation_color(self, relation_type: str) -> str:
        relation_colors = {
            'is_a': 'blue',
            'has_part': 'green',
            'related_to': 'purple',
            'synonym': 'orange',
            'antonym': 'red',
            'instance_of': 'brown',
            'causes': 'magenta',
            'used_for': 'cyan',
        }
        
        return relation_colors.get(relation_type, 'gray')
